"""
Parts of this code are from torchvision and thus licensed under

BSD 3-Clause License

Copyright (c) Soumith Chintala 2016, 
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this
  list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice,
  this list of conditions and the following disclaimer in the documentation
  and/or other materials provided with the distribution.

* Neither the name of the copyright holder nor the names of its
  contributors may be used to endorse or promote products derived from
  this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

"""

from typing import List, Sequence, Tuple, TypeVar

import torch
from torch import Tensor

from torchvision.models.detection._utils import BoxCoder


@torch.jit.script
def encode_boxes(reference_boxes: torch.Tensor,
                 proposals: torch.Tensor,
                 weights: torch.Tensor,
                 ) -> torch.Tensor:
    """
    Encode a set of proposals with respect to some reference boxes

    Args:
        reference_boxes: reference boxes (x1, y1, x2, y2, (z1, z2))
        proposals: boxes to be encoded (x1, y1, x2, y2, (z1, z2))
        weights: weights for dimensions (wx, wy, ww, wh, (wz, wd))
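
    Example (illustrative sketch, assuming 2D boxes of shape [N, 4])::

        >>> ref = torch.tensor([[0., 0., 10., 10.]])
        >>> props = torch.tensor([[1., 1., 9., 9.]])
        >>> encode_boxes(ref, props, torch.ones(4)).shape
        torch.Size([1, 4])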
    """
    # perform some unpacking to make it JIT-fusion friendly
    wx = weights[0]
    wy = weights[1]
    ww = weights[2]
    wh = weights[3]

    proposals_x1 = proposals[:, 0].unsqueeze(1)
    proposals_y1 = proposals[:, 1].unsqueeze(1)
    proposals_x2 = proposals[:, 2].unsqueeze(1)
    proposals_y2 = proposals[:, 3].unsqueeze(1)

    reference_boxes_x1 = reference_boxes[:, 0].unsqueeze(1)
    reference_boxes_y1 = reference_boxes[:, 1].unsqueeze(1)
    reference_boxes_x2 = reference_boxes[:, 2].unsqueeze(1)
    reference_boxes_y2 = reference_boxes[:, 3].unsqueeze(1)

    # implementation starts here
    ex_widths = proposals_x2 - proposals_x1
    ex_heights = proposals_y2 - proposals_y1
    ex_ctr_x = proposals_x1 + 0.5 * ex_widths
    ex_ctr_y = proposals_y1 + 0.5 * ex_heights

    gt_widths = reference_boxes_x2 - reference_boxes_x1
    gt_heights = reference_boxes_y2 - reference_boxes_y1
    gt_ctr_x = reference_boxes_x1 + 0.5 * gt_widths
    gt_ctr_y = reference_boxes_y1 + 0.5 * gt_heights

    targets_dx = wx * (gt_ctr_x - ex_ctr_x) / ex_widths
    targets_dy = wy * (gt_ctr_y - ex_ctr_y) / ex_heights
    targets_dw = ww * torch.log(gt_widths / ex_widths)
    targets_dh = wh * torch.log(gt_heights / ex_heights)

    if proposals.shape[1] == 6:
        wz = weights[4]
        wd = weights[5]

        proposals_z1 = proposals[:, 4].unsqueeze(1)
        proposals_z2 = proposals[:, 5].unsqueeze(1)
        ex_depth = proposals_z2 - proposals_z1
        ex_ctr_z = proposals_z1 + 0.5 * ex_depth

        reference_boxes_z1 = reference_boxes[:, 4].unsqueeze(1)
        reference_boxes_z2 = reference_boxes[:, 5].unsqueeze(1)
        gt_depth = reference_boxes_z2 - reference_boxes_z1
        gt_ctr_z = reference_boxes_z1 + 0.5 * gt_depth

        targets_dz = wz * (gt_ctr_z - ex_ctr_z) / ex_depth
        targets_dd = wd * torch.log(gt_depth / ex_depth)

        targets = torch.cat((targets_dx, targets_dy, targets_dw, targets_dh,
                             targets_dz, targets_dd), dim=1)
    else:
        targets = torch.cat((targets_dx, targets_dy, targets_dw, targets_dh), dim=1)
    return targets


def decode_single(rel_codes: Tensor, boxes: Tensor,
                  weights: Sequence[float],
                  bbox_xform_clip: float) -> Tensor:
    """
    From a set of original boxes and encoded relative box offsets,
    get the decoded boxes.

    Args:
        rel_codes: encoded boxes, shape [N, dim * 2] (dx, dy, dw, dh, (dz, dd))
        boxes: reference boxes (x1, y1, x2, y2, (z1, z2))
        weights: weights for dimensions (wx, wy, ww, wh, (wz, wd))
        bbox_xform_clip: upper bound for the size offsets before they are
            passed to torch.exp(), preventing overflow
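
    Example (illustrative sketch; an all-zero code decodes back to the
    reference box)::

        >>> import math
        >>> boxes = torch.tensor([[0., 0., 10., 10.]])
        >>> decoded = decode_single(torch.zeros(1, 4), boxes,
        ...                         (1., 1., 1., 1.), math.log(1000. / 16))
        >>> torch.allclose(decoded, boxes)
        True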
    """
    # offset is the number of box coordinates: 4 for 2D boxes, 6 for 3D boxes
    offset = boxes.shape[1]
    boxes = boxes.to(rel_codes.dtype)

    widths = boxes[:, 2] - boxes[:, 0]
    heights = boxes[:, 3] - boxes[:, 1]
    ctr_x = boxes[:, 0] + 0.5 * widths
    ctr_y = boxes[:, 1] + 0.5 * heights

    wx = weights[0]
    wy = weights[1]
    ww = weights[2]
    wh = weights[3]

    dx = rel_codes[:, 0::offset] / wx
    dy = rel_codes[:, 1::offset] / wy
    dw = rel_codes[:, 2::offset] / ww
    dh = rel_codes[:, 3::offset] / wh

    # Prevent sending too large values into torch.exp()
    dw = torch.clamp(dw, max=bbox_xform_clip)
    dh = torch.clamp(dh, max=bbox_xform_clip)

    pred_ctr_x = dx * widths[:, None] + ctr_x[:, None]
    pred_ctr_y = dy * heights[:, None] + ctr_y[:, None]
    pred_w = torch.exp(dw) * widths[:, None]
    pred_h = torch.exp(dh) * heights[:, None]

    pred_boxes1 = pred_ctr_x - torch.tensor(0.5, dtype=pred_ctr_x.dtype) * pred_w
    pred_boxes2 = pred_ctr_y - torch.tensor(0.5, dtype=pred_ctr_y.dtype) * pred_h
    pred_boxes3 = pred_ctr_x + torch.tensor(0.5, dtype=pred_ctr_x.dtype) * pred_w
    pred_boxes4 = pred_ctr_y + torch.tensor(0.5, dtype=pred_ctr_y.dtype) * pred_h

    if offset == 6:
        depths = boxes[:, 5] - boxes[:, 4]
        ctr_z = boxes[:, 4] + 0.5 * depths

        wz = weights[4]
        wd = weights[5]

        dz = rel_codes[:, 4::offset] / wz
        dd = rel_codes[:, 5::offset] / wd
        dd = torch.clamp(dd, max=bbox_xform_clip)

        pred_ctr_z = dz * depths[:, None] + ctr_z[:, None]
        pred_z = torch.exp(dd) * depths[:, None]

        pred_boxes5 = pred_ctr_z - torch.tensor(0.5, dtype=pred_ctr_z.dtype) * pred_z
        pred_boxes6 = pred_ctr_z + torch.tensor(0.5, dtype=pred_ctr_z.dtype) * pred_z
        pred_boxes = torch.stack((pred_boxes1, pred_boxes2, pred_boxes3, pred_boxes4,
                                  pred_boxes5, pred_boxes6), dim=2).flatten(1)
    else:
        pred_boxes = torch.stack((pred_boxes1, pred_boxes2, pred_boxes3, pred_boxes4),
                                 dim=2).flatten(1)
    return pred_boxes


class BoxCoderND(BoxCoder):
    """
    This class encodes and decodes a set of bounding boxes into
    the representation used for training the regressors.
    Compatible with 2D and 3D boxes.
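
    Example (illustrative sketch; the unit weights are an arbitrary choice
    for this demo)::

        >>> coder = BoxCoderND(weights=(1., 1., 1., 1.))
        >>> boxes = [torch.tensor([[0., 0., 10., 10.]])]
        >>> codes = coder.encode(boxes, boxes)
        >>> torch.allclose(coder.decode(codes[0], boxes), boxes[0])
        True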
    """
    def encode(self,
               reference_boxes: List[Tensor],
               proposals: List[Tensor],
               ) -> List[Tensor]:
        """
        Encode a set of proposals with respect to some reference boxes

        Args:
            reference_boxes: reference boxes for each image.
                (x1, y1, x2, y2, (z1, z2))
            proposals: proposals for each image
                (x1, y1, x2, y2, (z1, z2))

        Returns:
            List[Tensor]: regression targets for each image
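
        Example (illustrative sketch; the second image has no ground truth
        and therefore receives all-zero targets)::

            >>> coder = BoxCoderND(weights=(1., 1., 1., 1.))
            >>> gt = [torch.tensor([[0., 0., 4., 4.]]), torch.empty(0, 4)]
            >>> props = [torch.tensor([[1., 1., 5., 5.]]),
            ...          torch.tensor([[0., 0., 2., 2.]])]
            >>> coder.encode(gt, props)[1]
            tensor([[0., 0., 0., 0.]])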
        """
        # filter for images which have a foreground class
        filter_min_one_gt = [rb.numel() > 0 for rb in reference_boxes]
        filtered_ref_boxes = [
            rb for idx, rb in enumerate(reference_boxes) if filter_min_one_gt[idx]]
        filtered_proposals = [
            pr for idx, pr in enumerate(proposals) if filter_min_one_gt[idx]]

        filtered_encoded: List[Tensor] = []
        if any(filter_min_one_gt):
            filtered_encoded = list(
                super().encode(filtered_ref_boxes, filtered_proposals))

        # fill image with no ground truth
        idx_enc = 0
        encoded = []
        for img_idx, gt_present in enumerate(filter_min_one_gt):
            if gt_present:
                encoded.append(filtered_encoded[idx_enc])
                idx_enc += 1
            else:
                # fill with zeros because they do not contribute to the
                # regression loss anyway (all anchors are labeled as background)
                encoded.append(torch.zeros_like(proposals[img_idx]))
        return encoded

    def encode_single(self,
                      reference_boxes: Tensor,
                      proposals: Tensor,
                      ) -> Tensor:
        """
        Encode a set of proposals with respect to some reference boxes

        Args:
            reference_boxes: reference boxes (x1, y1, x2, y2, (z1, z2))
            proposals: boxes to be encoded (x1, y1, x2, y2, (z1, z2))
        """
        dtype, device = reference_boxes.dtype, reference_boxes.device
        weights = torch.tensor(self.weights, dtype=dtype, device=device)
        targets = encode_boxes(reference_boxes, proposals, weights)
        return targets

    def decode(self, rel_codes: Tensor, boxes: List[Tensor]) -> Tensor:
        """
        Decode boxes

        Args:
            rel_codes: relative offsets to the reference boxes,
                shape [N, dim * 2] with entries (dx, dy, dw, dh, (dz, dd))
            boxes: list of reference boxes per image
                (x1, y1, x2, y2, (z1, z2))

        Returns:
            Tensor: decoded boxes
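
        Example (illustrative sketch; zero offsets decode back to the
        reference boxes)::

            >>> coder = BoxCoderND(weights=(1., 1., 1., 1.))
            >>> boxes = [torch.tensor([[0., 0., 4., 4.], [1., 1., 3., 3.]]),
            ...          torch.tensor([[0., 0., 2., 2.]])]
            >>> coder.decode(torch.zeros(3, 4), boxes).shape
            torch.Size([3, 4])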
        """
        assert isinstance(boxes, (list, tuple))
        assert isinstance(rel_codes, torch.Tensor)
        boxes_per_image = [b.size(0) for b in boxes]
        concat_boxes = torch.cat(boxes, dim=0)
        box_dim = concat_boxes.shape[1]  # 4 for 2D boxes, 6 for 3D boxes
        box_sum = sum(boxes_per_image)
        pred_boxes = self.decode_single(rel_codes.reshape(box_sum, -1), concat_boxes)
        return pred_boxes.reshape(box_sum, box_dim)

    def decode_single(self, rel_codes: Tensor, boxes: Tensor) -> Tensor:
        return decode_single(rel_codes, boxes, self.weights, self.bbox_xform_clip)


CoderType = TypeVar('CoderType', bound=BoxCoderND)
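

if __name__ == "__main__":
    # Minimal smoke test (illustrative only): round-trip a single 3D box
    # through BoxCoderND. The unit weights are an arbitrary assumption made
    # for this demo, not a default required by the coder.
    coder = BoxCoderND(weights=(1.0, 1.0, 1.0, 1.0, 1.0, 1.0))
    boxes_3d = [torch.tensor([[0.0, 0.0, 10.0, 10.0, 2.0, 6.0]])]
    codes = coder.encode(boxes_3d, boxes_3d)
    decoded = coder.decode(codes[0], boxes_3d)
    assert torch.allclose(decoded, boxes_3d[0])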
